{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from pathlib import Path\n",
    "from sagemaker.predictor import json_serializer\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Session object used for the S3 uploads and AWS lookups in later cells.\n",
    "session = sagemaker.Session()\n",
    "\n",
    "# Execution role that is handed to the Estimator further down.\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local directory that holds train.csv, val.csv and labels.csv\n",
    "DATA_PATH = Path(\"../data/\")\n",
    "\n",
    "# Sub-directory of DATA_PATH where training_config.json is written;\n",
    "# created on first run, left untouched if it is already there.\n",
    "CONFIG_PATH = DATA_PATH.joinpath('config')\n",
    "CONFIG_PATH.mkdir(exist_ok=True)\n",
    "\n",
    "# Name of the S3 bucket used for both training input and output\n",
    "bucket = 'sagemaker-deep-learning'\n",
    "\n",
    "# Key prefixes inside the bucket: uploaded data vs. training output\n",
    "prefix = 'toxic_comments/input'\n",
    "prefix_output = 'toxic_comments/output'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters & Training Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters handed to the Estimator when it is created.\n",
    "hyperparameters = dict(\n",
    "    epochs=10,\n",
    "    lr=8e-5,\n",
    "    max_seq_length=512,\n",
    "    train_batch_size=16,\n",
    "    lr_schedule=\"warmup_cosine\",\n",
    "    warmup_steps=1000,\n",
    "    optimizer_type=\"adamw\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings written to training_config.json. Apart from finetuned_model\n",
    "# (None, i.e. null in the JSON file) every value is stored as a string,\n",
    "# including the label column list, which is itself a JSON-formatted string.\n",
    "training_config = {\n",
    "    \"run_text\": \"toxic comments\",\n",
    "    \"finetuned_model\": None,\n",
    "    \"do_lower_case\": \"True\",\n",
    "    \"train_file\": \"train.csv\",\n",
    "    \"val_file\": \"val.csv\",\n",
    "    \"label_file\": \"labels.csv\",\n",
    "    \"text_col\": \"comment_text\",\n",
    "    \"label_col\": '[\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]',\n",
    "    \"multi_label\": \"True\",\n",
    "    \"grad_accumulation_steps\": \"1\",\n",
    "    \"fp16_opt_level\": \"O1\",\n",
    "    \"fp16\": \"True\",\n",
    "    \"model_type\": \"roberta\",\n",
    "    \"model_name\": \"roberta-base\",\n",
    "    \"logging_steps\": \"300\"\n",
    "}\n",
    "\n",
    "# CONFIG_PATH lives inside DATA_PATH, which a later cell uploads to S3.\n",
    "config_file = CONFIG_PATH/'training_config.json'\n",
    "with config_file.open('w') as handle:\n",
    "    handle.write(json.dumps(training_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast if one of the expected CSV files is missing locally,\n",
    "# rather than discovering it inside the training job.\n",
    "required_files = ['train.csv', 'val.csv', 'labels.csv']\n",
    "missing_files = [name for name in required_files if not (DATA_PATH/name).is_file()]\n",
    "if missing_files:\n",
    "    raise FileNotFoundError(\n",
    "        'Missing in {}: {}'.format(DATA_PATH, ', '.join(missing_files))\n",
    "    )\n",
    "\n",
    "# Upload the whole data directory (CSV files plus config/) to S3.\n",
    "# upload_data documents `path` as a str, so the Path is converted explicitly;\n",
    "# a directory is uploaded recursively, so no per-file uploads are needed.\n",
    "# The returned S3 URI is what estimator.fit() receives as its input.\n",
    "s3_input = session.upload_data(str(DATA_PATH), bucket=bucket, key_prefix=prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create an Estimator and start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The image URI is built from the caller's account id (looked up via STS)\n",
    "# and the region of the current boto session.\n",
    "boto_session = session.boto_session\n",
    "account = boto_session.client('sts').get_caller_identity()['Account']\n",
    "region = boto_session.region_name\n",
    "\n",
    "image = \"{}.dkr.ecr.{}.amazonaws.com/sagemaker-bert:1.0-gpu-py36\".format(account, region)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S3 URI of the form s3://<bucket>/<prefix_output>, passed to the Estimator\n",
    "# as its output_path.\n",
    "output_path = \"s3://{}/{}\".format(\n",
    "    bucket,\n",
    "    prefix_output,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training job definition: one ml.p3.8xlarge instance running the custom\n",
    "# image, with output written under output_path.\n",
    "estimator = sagemaker.estimator.Estimator(\n",
    "    image,\n",
    "    role,\n",
    "    train_instance_count=1,\n",
    "    train_instance_type='ml.p3.8xlarge',\n",
    "    output_path=output_path,\n",
    "    base_job_name='toxic-comments',\n",
    "    hyperparameters=hyperparameters,\n",
    "    sagemaker_session=session,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the training job on the S3 location returned by the upload step.\n",
    "estimator.fit(inputs=s3_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy the model to the hosting service"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Host the trained model on a single ml.m5.large instance behind the\n",
    "# 'bert-toxic-comments' endpoint, using json_serializer for requests.\n",
    "# NOTE(review): update_endpoint=True presumably expects that endpoint to\n",
    "# exist already - confirm the behaviour for a first-time deployment.\n",
    "predictor = estimator.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type='ml.m5.large',\n",
    "    endpoint_name='bert-toxic-comments',\n",
    "    update_endpoint=True,\n",
    "    serializer=json_serializer,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
